{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Prepare model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class NaiveModel(torch.nn.Module):\n",
    "    \"\"\"Small LeNet-style CNN for single-channel MNIST digits with 10 output classes.\n",
    "\n",
    "    Two conv(5x5) -> ReLU6 -> max-pool(2x2) stages feed two fully connected\n",
    "    layers. Submodules are registered in a fixed order, which is the order in\n",
    "    which `named_modules()` reports them.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)\n",
    "        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)\n",
    "        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)\n",
    "        self.fc2 = torch.nn.Linear(500, 10)\n",
    "        self.relu1 = torch.nn.ReLU6()\n",
    "        self.relu2 = torch.nn.ReLU6()\n",
    "        self.relu3 = torch.nn.ReLU6()\n",
    "        self.max_pool1 = torch.nn.MaxPool2d(2, 2)\n",
    "        self.max_pool2 = torch.nn.MaxPool2d(2, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return per-class log-probabilities (along dim 1) for an image batch `x`.\n",
    "\n",
    "        `x` is assumed to be N x 1 x 28 x 28, which is what makes the flattened\n",
    "        feature size equal 4 * 4 * 50 as expected by `fc1`.\n",
    "        \"\"\"\n",
    "        features = self.max_pool1(self.relu1(self.conv1(x)))\n",
    "        features = self.max_pool2(self.relu2(self.conv2(features)))\n",
    "        # collapse (channels, height, width) into a single axis, keeping the batch axis\n",
    "        features = torch.flatten(features, start_dim=1)\n",
    "        logits = self.fc2(self.relu3(self.fc1(features)))\n",
    "        return F.log_softmax(logits, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define model, optimizer, criterion, data loaders, trainer and evaluator.\n",
    "\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# seed PyTorch's RNG so weight initialization and batch shuffling repeat across runs\n",
    "# (some GPU kernels may still be nondeterministic).\n",
    "torch.manual_seed(1)\n",
    "\n",
    "model = NaiveModel().to(device)\n",
    "\n",
    "optimizer = optim.Adadelta(model.parameters(), lr=1)\n",
    "\n",
    "criterion = torch.nn.NLLLoss()\n",
    "\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n",
    "train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)\n",
    "test_dataset = datasets.MNIST('./data', train=False, transform=transform)\n",
    "# reshuffle the training set every epoch so batches are not always seen in the same order;\n",
    "# the evaluator only sums loss / correct counts, so the test set can stay unshuffled.\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000)\n",
    "\n",
    "def trainer(model, optimizer, criterion, epoch):\n",
    "    \"\"\"Train `model` for one epoch over `train_loader`.\n",
    "\n",
    "    Prints the current batch loss every 100 batches; `epoch` is only used in the log line.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = criterion(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_idx % 100 == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), loss.item()))\n",
    "\n",
    "def evaluator(model):\n",
    "    \"\"\"Evaluate `model` on `test_loader`.\n",
    "\n",
    "    Prints the mean per-sample NLL loss and the accuracy, and returns the\n",
    "    accuracy as a percentage (0-100).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item()\n",
    "            pred = output.argmax(dim=1, keepdim=True)\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    acc = 100 * correct / len(test_loader.dataset)\n",
    "\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset), acc))\n",
    "\n",
    "    return acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 0 [0/60000 (0%)]\tLoss: 2.313423\n",
      "Train Epoch: 0 [6400/60000 (11%)]\tLoss: 0.091786\n",
      "Train Epoch: 0 [12800/60000 (21%)]\tLoss: 0.087317\n",
      "Train Epoch: 0 [19200/60000 (32%)]\tLoss: 0.036397\n",
      "Train Epoch: 0 [25600/60000 (43%)]\tLoss: 0.008173\n",
      "Train Epoch: 0 [32000/60000 (53%)]\tLoss: 0.047565\n",
      "Train Epoch: 0 [38400/60000 (64%)]\tLoss: 0.122448\n",
      "Train Epoch: 0 [44800/60000 (75%)]\tLoss: 0.036732\n",
      "Train Epoch: 0 [51200/60000 (85%)]\tLoss: 0.150135\n",
      "Train Epoch: 0 [57600/60000 (96%)]\tLoss: 0.109684\n",
      "\n",
      "Test set: Average loss: 0.0457, Accuracy: 9857/10000 (99%)\n",
      "\n",
      "Train Epoch: 1 [0/60000 (0%)]\tLoss: 0.020650\n",
      "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.091525\n",
      "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 0.019602\n",
      "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 0.027827\n",
      "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.019414\n",
      "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.007640\n",
      "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.051296\n",
      "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.012038\n",
      "Train Epoch: 1 [51200/60000 (85%)]\tLoss: 0.121057\n",
      "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 0.015796\n",
      "\n",
      "Test set: Average loss: 0.0302, Accuracy: 9902/10000 (99%)\n",
      "\n",
      "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.009903\n",
      "Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.062256\n",
      "Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.013844\n",
      "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.014133\n",
      "Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.001051\n",
      "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.006128\n",
      "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.032162\n",
      "Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.007687\n",
      "Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.092295\n",
      "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.006266\n",
      "\n",
      "Test set: Average loss: 0.0259, Accuracy: 9920/10000 (99%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# pre-train the model for 3 epochs; StepLR multiplies the learning rate by 0.7 after every epoch.\n",
    "\n",
    "PRETRAIN_EPOCHS = 3\n",
    "scheduler = StepLR(optimizer, step_size=1, gamma=0.7)\n",
    "\n",
    "for epoch in range(PRETRAIN_EPOCHS):\n",
    "    trainer(model, optimizer, criterion, epoch)\n",
    "    evaluator(model)\n",
    "    scheduler.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "op_name: \n",
      "op_type: <class '__main__.NaiveModel'>\n",
      "\n",
      "op_name: conv1\n",
      "op_type: <class 'torch.nn.modules.conv.Conv2d'>\n",
      "\n",
      "op_name: conv2\n",
      "op_type: <class 'torch.nn.modules.conv.Conv2d'>\n",
      "\n",
      "op_name: fc1\n",
      "op_type: <class 'torch.nn.modules.linear.Linear'>\n",
      "\n",
      "op_name: fc2\n",
      "op_type: <class 'torch.nn.modules.linear.Linear'>\n",
      "\n",
      "op_name: relu1\n",
      "op_type: <class 'torch.nn.modules.activation.ReLU6'>\n",
      "\n",
      "op_name: relu2\n",
      "op_type: <class 'torch.nn.modules.activation.ReLU6'>\n",
      "\n",
      "op_name: relu3\n",
      "op_type: <class 'torch.nn.modules.activation.ReLU6'>\n",
      "\n",
      "op_name: max_pool1\n",
      "op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>\n",
      "\n",
      "op_name: max_pool2\n",
      "op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# show every op_name and op_type in the model.\n",
    "\n",
    "for name, module in model.named_modules():\n",
    "    print('op_name: {}\\nop_type: {}\\n'.format(name, type(module)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([20, 1, 5, 5])\n"
     ]
    }
   ],
   "source": [
    "# print the shape of `conv1`'s weight tensor.\n",
    "\n",
    "print(model.conv1.weight.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[ 1.5338e-01, -1.1766e-01, -2.6654e-01, -2.9445e-02, -1.4650e-01],\n",
      "          [-1.8796e-01, -2.9882e-01,  6.9725e-02,  2.1561e-01,  6.5688e-02],\n",
      "          [ 1.5274e-01, -9.8471e-03,  3.2303e-01,  1.3472e-03,  1.7235e-01],\n",
      "          [ 1.1804e-01,  2.2535e-01, -8.3370e-02, -3.4553e-02, -1.2529e-01],\n",
      "          [-6.6012e-02, -2.0272e-02, -1.8797e-01, -4.6882e-02, -8.3206e-02]]],\n",
      "\n",
      "\n",
      "        [[[-1.2112e-01,  7.0756e-02,  5.0446e-02,  1.5156e-01, -2.7929e-02],\n",
      "          [-1.9744e-01, -2.1336e-03,  7.2534e-02,  6.2336e-02,  1.6039e-01],\n",
      "          [-6.7510e-02,  1.4636e-01,  7.1972e-02, -8.9118e-02, -4.0895e-02],\n",
      "          [ 2.9499e-02,  2.0788e-01, -1.4989e-01,  1.1668e-01, -2.8503e-01],\n",
      "          [ 8.1894e-02, -1.4489e-01, -4.2038e-02, -1.2794e-01, -5.0379e-02]]],\n",
      "\n",
      "\n",
      "        [[[ 3.8332e-02, -1.4270e-01, -1.9585e-01,  2.2653e-01,  1.0104e-01],\n",
      "          [-2.7956e-03, -1.4108e-01, -1.4694e-01, -1.3525e-01,  2.6959e-01],\n",
      "          [ 1.9522e-01, -1.2281e-01, -1.9173e-01, -1.8910e-02,  3.1572e-03],\n",
      "          [-1.0580e-01, -2.5239e-02, -5.8266e-02, -6.5815e-02,  6.6433e-02],\n",
      "          [ 8.9601e-02,  7.1189e-02, -2.4255e-01,  1.5746e-01, -1.4708e-01]]],\n",
      "\n",
      "\n",
      "        [[[-1.1963e-01, -1.7243e-01, -3.5174e-02,  1.4651e-01, -1.1675e-01],\n",
      "          [-1.3518e-01,  1.2830e-02,  7.7188e-02,  2.1060e-01,  4.0924e-02],\n",
      "          [-4.3364e-02, -1.9579e-01, -3.6559e-02, -6.9803e-02,  1.2380e-01],\n",
      "          [ 7.7321e-02,  3.7590e-02,  8.2935e-02,  2.2878e-01,  2.7859e-03],\n",
      "          [-1.3601e-01, -2.1167e-01, -2.3195e-01, -1.2524e-01,  1.0073e-01]]],\n",
      "\n",
      "\n",
      "        [[[-2.7300e-01,  6.8470e-02,  2.8405e-02, -4.5879e-03, -1.3735e-01],\n",
      "          [-8.9789e-02, -2.0209e-03,  5.0950e-03,  2.1633e-01,  2.5554e-01],\n",
      "          [ 5.4389e-02,  1.2262e-01, -1.5514e-01, -1.0416e-01,  1.3606e-01],\n",
      "          [-1.6794e-01, -2.8876e-02,  2.5900e-02, -2.4261e-02,  1.0923e-01],\n",
      "          [ 5.2524e-03, -4.4625e-02, -2.1327e-01, -1.7211e-01, -4.4819e-04]]],\n",
      "\n",
      "\n",
      "        [[[ 7.2378e-02,  1.5122e-01, -1.2964e-01,  4.9105e-02, -2.1639e-01],\n",
      "          [ 3.6547e-02, -1.5518e-02,  3.2059e-02, -3.2820e-02,  6.1231e-02],\n",
      "          [ 1.2514e-01,  8.0623e-02,  1.2686e-02, -1.0074e-01,  2.2836e-02],\n",
      "          [-2.6842e-02,  2.5578e-02, -2.5877e-01, -1.7808e-01,  7.6966e-02],\n",
      "          [-4.2424e-02,  4.7006e-02, -1.5486e-02, -4.2686e-02,  4.8482e-02]]],\n",
      "\n",
      "\n",
      "        [[[ 1.3081e-01,  9.9530e-02, -1.4729e-01, -1.7665e-01, -1.9757e-01],\n",
      "          [ 9.6603e-02,  2.2783e-02,  7.8402e-02, -2.8679e-02,  8.5252e-02],\n",
      "          [-1.5310e-02,  1.1605e-01, -5.8300e-02,  2.4563e-02,  1.7488e-01],\n",
      "          [ 6.5576e-02, -1.6325e-01, -1.1318e-01, -2.9251e-02,  6.2352e-02],\n",
      "          [-1.9084e-03, -1.4005e-01, -1.2363e-01, -9.7985e-02, -2.0562e-01]]],\n",
      "\n",
      "\n",
      "        [[[ 4.0772e-02, -8.2086e-02, -2.7555e-01, -3.2547e-01, -1.2226e-01],\n",
      "          [-5.9877e-02,  9.8567e-02,  2.5186e-01, -1.0280e-01, -2.3416e-01],\n",
      "          [ 8.5760e-02,  1.0896e-01,  1.4898e-01,  2.1579e-01,  8.5297e-02],\n",
      "          [ 5.4720e-02, -1.7226e-01, -7.2518e-02,  6.7099e-03, -1.6011e-03],\n",
      "          [-8.9944e-02,  1.7404e-01, -3.6985e-02,  1.8602e-01,  7.2353e-02]]],\n",
      "\n",
      "\n",
      "        [[[ 1.6276e-02, -9.6439e-02, -9.6085e-02, -2.4267e-01, -1.8521e-01],\n",
      "          [ 6.3310e-02,  1.7866e-01,  1.1694e-01, -1.4464e-01, -2.7711e-01],\n",
      "          [-2.4514e-02,  2.2222e-01,  2.1053e-01, -1.4271e-01,  8.7045e-02],\n",
      "          [-1.9207e-01, -5.4719e-02, -5.7775e-03, -1.0034e-05, -1.0923e-01],\n",
      "          [-2.4006e-02,  2.3780e-02,  1.8988e-01,  2.4734e-01,  4.8097e-02]]],\n",
      "\n",
      "\n",
      "        [[[ 1.1335e-01, -5.8451e-02,  5.2440e-02, -1.3223e-01, -2.5534e-02],\n",
      "          [ 9.1323e-02, -6.0707e-02,  2.3524e-01,  2.4992e-01,  8.7842e-02],\n",
      "          [ 2.9002e-02,  3.5379e-02, -5.9689e-02, -2.8363e-03,  1.8618e-01],\n",
      "          [-2.9671e-01,  8.1830e-03,  1.1076e-01, -5.4118e-02, -6.1685e-02],\n",
      "          [-1.7580e-01, -3.4534e-01, -3.9250e-01, -2.7569e-01, -2.6131e-01]]],\n",
      "\n",
      "\n",
      "        [[[ 1.1586e-01, -7.5997e-02, -1.4614e-01,  4.8750e-02,  1.8097e-01],\n",
      "          [-6.7027e-02, -1.4901e-01, -1.5614e-02, -1.0379e-02,  9.5526e-02],\n",
      "          [-3.2333e-02, -1.5107e-01, -1.9498e-01,  1.0083e-01,  2.2328e-01],\n",
      "          [-2.0692e-01, -6.3798e-02, -1.2524e-01,  1.9549e-01,  1.9682e-01],\n",
      "          [-2.1494e-01,  1.0475e-01, -2.4858e-02, -9.7831e-02,  1.1551e-01]]],\n",
      "\n",
      "\n",
      "        [[[ 6.3785e-02, -1.8044e-01, -1.0190e-01, -1.3588e-01,  8.5433e-02],\n",
      "          [ 2.0675e-01,  3.3238e-02,  9.2437e-02,  1.1799e-01,  2.1111e-01],\n",
      "          [-5.2138e-02,  1.5790e-01,  1.8151e-01,  8.0470e-02,  1.0131e-01],\n",
      "          [-4.4786e-02,  1.1771e-01,  2.1706e-02, -1.2563e-01, -2.1142e-01],\n",
      "          [-2.3589e-01, -2.1154e-01, -1.7890e-01, -2.7769e-01, -1.2512e-01]]],\n",
      "\n",
      "\n",
      "        [[[ 1.9133e-01,  2.4711e-01,  1.0413e-01, -1.9187e-01, -3.0991e-01],\n",
      "          [-1.2382e-01,  8.3641e-03, -5.6734e-02,  5.8376e-02,  2.2880e-02],\n",
      "          [-3.1734e-01, -1.0637e-02, -5.5974e-02,  1.0676e-01, -1.1080e-02],\n",
      "          [-2.2980e-01,  2.0486e-01,  1.0147e-01,  1.4484e-01,  5.2265e-02],\n",
      "          [ 7.4410e-02,  2.2806e-02,  8.5137e-02, -2.1809e-01,  3.1704e-02]]],\n",
      "\n",
      "\n",
      "        [[[-1.1006e-01, -2.5311e-01,  1.8925e-02,  1.0399e-02,  1.1951e-01],\n",
      "          [-2.1116e-01,  1.8409e-01,  3.2172e-02,  1.5962e-01, -7.9457e-02],\n",
      "          [ 1.1059e-01,  9.1966e-02,  1.0777e-01, -9.9132e-02, -4.4586e-02],\n",
      "          [-8.7919e-02, -3.7283e-02,  9.1275e-02, -3.7412e-02,  3.8875e-02],\n",
      "          [-4.3558e-02,  1.6196e-01, -4.7944e-03, -1.7560e-02, -1.2593e-01]]],\n",
      "\n",
      "\n",
      "        [[[ 7.6976e-02, -3.8627e-02,  1.2610e-01,  1.1994e-01,  2.1706e-03],\n",
      "          [ 7.4357e-02,  6.7929e-02,  3.1386e-02,  1.4606e-01,  2.1429e-01],\n",
      "          [-2.6569e-01, -4.2631e-04, -3.6654e-02, -3.0967e-02, -9.4961e-02],\n",
      "          [-2.0192e-01, -3.5423e-01, -2.5246e-01, -3.5092e-01, -2.4159e-01],\n",
      "          [ 1.7636e-02,  1.3744e-01, -1.0306e-01,  8.8370e-02,  7.3258e-02]]],\n",
      "\n",
      "\n",
      "        [[[ 2.0016e-01,  1.0956e-01, -5.9223e-02,  6.4871e-03, -2.4165e-01],\n",
      "          [ 5.6283e-02,  1.7276e-01, -2.2316e-01, -1.6699e-01, -7.0742e-02],\n",
      "          [ 2.6179e-01, -2.5102e-01, -2.0774e-01, -9.6413e-02,  3.4367e-02],\n",
      "          [-9.1882e-02, -2.9195e-01, -8.7432e-02,  1.0144e-01, -2.0559e-02],\n",
      "          [-2.5668e-01, -9.8016e-02,  1.1103e-01, -3.0233e-02,  1.1076e-01]]],\n",
      "\n",
      "\n",
      "        [[[ 1.0027e-03, -5.7955e-02, -2.1339e-01, -1.6729e-01, -2.0870e-01],\n",
      "          [ 4.2464e-02,  2.3177e-01, -6.1459e-02, -1.0905e-01,  1.7613e-02],\n",
      "          [-1.2282e-01,  2.1762e-01, -1.3553e-02,  2.7476e-01,  1.6703e-01],\n",
      "          [-5.6282e-02,  1.2731e-02,  1.0944e-01, -1.7347e-01,  4.4497e-02],\n",
      "          [ 5.7346e-02, -5.4657e-02,  4.8718e-02, -2.6221e-02, -2.6933e-02]]],\n",
      "\n",
      "\n",
      "        [[[ 6.7697e-02,  1.5692e-01,  2.7050e-01,  1.5936e-02,  1.7659e-01],\n",
      "          [-2.8899e-02, -1.4866e-01,  3.1838e-02,  1.0903e-01,  1.2292e-01],\n",
      "          [-1.3608e-01, -4.3198e-03, -9.8925e-02, -4.5599e-02,  1.3452e-01],\n",
      "          [-5.1435e-02, -2.3815e-01, -2.4151e-01, -4.8556e-02,  1.3825e-01],\n",
      "          [-1.2823e-01,  8.9324e-03, -1.5313e-01, -2.2933e-01, -3.4081e-02]]],\n",
      "\n",
      "\n",
      "        [[[-1.8396e-01, -6.8774e-03, -1.6675e-01,  7.1980e-03,  1.9922e-02],\n",
      "          [ 1.3416e-01, -1.1450e-01, -1.5277e-01, -6.5713e-02, -9.5435e-02],\n",
      "          [ 1.5406e-01, -9.1235e-02, -1.0880e-01, -7.1603e-02, -9.5575e-02],\n",
      "          [ 2.1772e-01,  8.4073e-02, -2.5264e-01, -2.1428e-01,  1.9537e-01],\n",
      "          [ 1.3124e-01,  7.9532e-02, -2.4044e-01, -1.5717e-01,  1.6562e-01]]],\n",
      "\n",
      "\n",
      "        [[[ 1.1849e-01, -5.0517e-03, -1.8900e-01,  1.8093e-02,  6.4660e-02],\n",
      "          [-1.5309e-01, -2.0106e-01, -8.6551e-02,  5.2692e-03,  1.5448e-01],\n",
      "          [-3.0727e-01,  4.9703e-02, -4.7637e-02,  2.9111e-01, -1.3173e-01],\n",
      "          [-8.5167e-02, -1.3540e-01,  2.9235e-01,  3.7895e-03, -9.4651e-02],\n",
      "          [-6.0694e-02,  9.6936e-02,  1.0533e-01, -6.1769e-02, -1.8086e-01]]]],\n",
      "       device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# print the weight values of `conv1` before pruning.\n",
    "\n",
    "print(model.conv1.weight.detach())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Prepare config_list for pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prune 50% of `conv1`; L1FilterPruner does this by zeroing whole filters (output channels).\n",
    "\n",
    "config_list = [\n",
    "    dict(\n",
    "        sparsity=0.5,\n",
    "        op_types=['Conv2d'],\n",
    "        op_names=['conv1'],\n",
    "    ),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Choose a pruner and prune the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use L1FilterPruner to prune the model.\n",
    "\n",
    "from nni.algorithms.compression.pytorch.pruning import L1FilterPruner\n",
    "\n",
    "# Note: if you use a pruner that takes an optimizer, give it a newly created optimizer\n",
    "# rather than the one used for pre-training above, because NNI may modify the optimizer\n",
    "# it is given, and that modified optimizer should not be reused for fine-tuning.\n",
    "# L1FilterPruner is constructed here without an optimizer.\n",
    "pruner = L1FilterPruner(model, config_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "op_name: \n",
      "op_type: <class '__main__.NaiveModel'>\n",
      "\n",
      "op_name: conv1\n",
      "op_type: <class 'nni.compression.pytorch.compressor.PrunerModuleWrapper'>\n",
      "\n",
      "op_name: conv1.module\n",
      "op_type: <class 'torch.nn.modules.conv.Conv2d'>\n",
      "\n",
      "op_name: conv2\n",
      "op_type: <class 'torch.nn.modules.conv.Conv2d'>\n",
      "\n",
      "op_name: fc1\n",
      "op_type: <class 'torch.nn.modules.linear.Linear'>\n",
      "\n",
      "op_name: fc2\n",
      "op_type: <class 'torch.nn.modules.linear.Linear'>\n",
      "\n",
      "op_name: relu1\n",
      "op_type: <class 'torch.nn.modules.activation.ReLU6'>\n",
      "\n",
      "op_name: relu2\n",
      "op_type: <class 'torch.nn.modules.activation.ReLU6'>\n",
      "\n",
      "op_name: relu3\n",
      "op_type: <class 'torch.nn.modules.activation.ReLU6'>\n",
      "\n",
      "op_name: max_pool1\n",
      "op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>\n",
      "\n",
      "op_name: max_pool2\n",
      "op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# `conv1` is now wrapped: the original `conv1` has moved to `conv1.module`.\n",
    "# In `forward()` the wrapper uses `weight * mask` as the weight; the initial mask is `ones_like(weight)`.\n",
    "\n",
    "for name, module in model.named_modules():\n",
    "    print('op_name: {}\\nop_type: {}\\n'.format(name, type(module)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NaiveModel(\n",
       "  (conv1): PrunerModuleWrapper(\n",
       "    (module): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n",
       "  )\n",
       "  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))\n",
       "  (fc1): Linear(in_features=800, out_features=500, bias=True)\n",
       "  (fc2): Linear(in_features=500, out_features=10, bias=True)\n",
       "  (relu1): ReLU6()\n",
       "  (relu2): ReLU6()\n",
       "  (relu3): ReLU6()\n",
       "  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# compress the model: this updates the mask of `conv1` (inspected in the next cells).\n",
    "# `compress()` returns the model, whose repr is displayed below.\n",
    "\n",
    "pruner.compress()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([20, 1, 5, 5])\n"
     ]
    }
   ],
   "source": [
    "# print the shape of `conv1`'s mask (the same as its weight shape).\n",
    "\n",
    "print(model.conv1.weight_mask.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]],\n",
      "\n",
      "\n",
      "        [[[1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]],\n",
      "\n",
      "\n",
      "        [[[1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "        [[[1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "        [[[1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]],\n",
      "\n",
      "\n",
      "        [[[1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "        [[[1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]],\n",
      "\n",
      "\n",
      "        [[[1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "        [[[1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]],\n",
      "\n",
      "\n",
      "        [[[1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1., 1.]]],\n",
      "\n",
      "\n",
      "        [[[0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.],\n",
      "          [0., 0., 0., 0., 0.]]]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# show the mask of `conv1`: each 5x5 kernel is either all ones (kept) or all zeros (pruned).\n",
    "\n",
    "print(model.conv1.weight_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[ 1.5338e-01, -1.1766e-01, -2.6654e-01, -2.9445e-02, -1.4650e-01],\n",
      "          [-1.8796e-01, -2.9882e-01,  6.9725e-02,  2.1561e-01,  6.5688e-02],\n",
      "          [ 1.5274e-01, -9.8471e-03,  3.2303e-01,  1.3472e-03,  1.7235e-01],\n",
      "          [ 1.1804e-01,  2.2535e-01, -8.3370e-02, -3.4553e-02, -1.2529e-01],\n",
      "          [-6.6012e-02, -2.0272e-02, -1.8797e-01, -4.6882e-02, -8.3206e-02]]],\n",
      "\n",
      "\n",
      "        [[[-0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
      "          [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],\n",
      "          [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n",
      "\n",
      "\n",
      "        [[[ 3.8332e-02, -1.4270e-01, -1.9585e-01,  2.2653e-01,  1.0104e-01],\n",
      "          [-2.7956e-03, -1.4108e-01, -1.4694e-01, -1.3525e-01,  2.6959e-01],\n",
      "          [ 1.9522e-01, -1.2281e-01, -1.9173e-01, -1.8910e-02,  3.1572e-03],\n",
      "          [-1.0580e-01, -2.5239e-02, -5.8266e-02, -6.5815e-02,  6.6433e-02],\n",
      "          [ 8.9601e-02,  7.1189e-02, -2.4255e-01,  1.5746e-01, -1.4708e-01]]],\n",
      "\n",
      "\n",
      "        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00]]],\n",
      "\n",
      "\n",
      "        [[[-0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [ 0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n",
      "\n",
      "\n",
      "        [[[ 0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],\n",
      "          [ 0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00]]],\n",
      "\n",
      "\n",
      "        [[[ 0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
      "          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n",
      "\n",
      "\n",
      "        [[[ 4.0772e-02, -8.2086e-02, -2.7555e-01, -3.2547e-01, -1.2226e-01],\n",
      "          [-5.9877e-02,  9.8567e-02,  2.5186e-01, -1.0280e-01, -2.3416e-01],\n",
      "          [ 8.5760e-02,  1.0896e-01,  1.4898e-01,  2.1579e-01,  8.5297e-02],\n",
      "          [ 5.4720e-02, -1.7226e-01, -7.2518e-02,  6.7099e-03, -1.6011e-03],\n",
      "          [-8.9944e-02,  1.7404e-01, -3.6985e-02,  1.8602e-01,  7.2353e-02]]],\n",
      "\n",
      "\n",
      "        [[[ 1.6276e-02, -9.6439e-02, -9.6085e-02, -2.4267e-01, -1.8521e-01],\n",
      "          [ 6.3310e-02,  1.7866e-01,  1.1694e-01, -1.4464e-01, -2.7711e-01],\n",
      "          [-2.4514e-02,  2.2222e-01,  2.1053e-01, -1.4271e-01,  8.7045e-02],\n",
      "          [-1.9207e-01, -5.4719e-02, -5.7775e-03, -1.0034e-05, -1.0923e-01],\n",
      "          [-2.4006e-02,  2.3780e-02,  1.8988e-01,  2.4734e-01,  4.8097e-02]]],\n",
      "\n",
      "\n",
      "        [[[ 1.1335e-01, -5.8451e-02,  5.2440e-02, -1.3223e-01, -2.5534e-02],\n",
      "          [ 9.1323e-02, -6.0707e-02,  2.3524e-01,  2.4992e-01,  8.7842e-02],\n",
      "          [ 2.9002e-02,  3.5379e-02, -5.9689e-02, -2.8363e-03,  1.8618e-01],\n",
      "          [-2.9671e-01,  8.1830e-03,  1.1076e-01, -5.4118e-02, -6.1685e-02],\n",
      "          [-1.7580e-01, -3.4534e-01, -3.9250e-01, -2.7569e-01, -2.6131e-01]]],\n",
      "\n",
      "\n",
      "        [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00]]],\n",
      "\n",
      "\n",
      "        [[[ 6.3785e-02, -1.8044e-01, -1.0190e-01, -1.3588e-01,  8.5433e-02],\n",
      "          [ 2.0675e-01,  3.3238e-02,  9.2437e-02,  1.1799e-01,  2.1111e-01],\n",
      "          [-5.2138e-02,  1.5790e-01,  1.8151e-01,  8.0470e-02,  1.0131e-01],\n",
      "          [-4.4786e-02,  1.1771e-01,  2.1706e-02, -1.2563e-01, -2.1142e-01],\n",
      "          [-2.3589e-01, -2.1154e-01, -1.7890e-01, -2.7769e-01, -1.2512e-01]]],\n",
      "\n",
      "\n",
      "        [[[ 1.9133e-01,  2.4711e-01,  1.0413e-01, -1.9187e-01, -3.0991e-01],\n",
      "          [-1.2382e-01,  8.3641e-03, -5.6734e-02,  5.8376e-02,  2.2880e-02],\n",
      "          [-3.1734e-01, -1.0637e-02, -5.5974e-02,  1.0676e-01, -1.1080e-02],\n",
      "          [-2.2980e-01,  2.0486e-01,  1.0147e-01,  1.4484e-01,  5.2265e-02],\n",
      "          [ 7.4410e-02,  2.2806e-02,  8.5137e-02, -2.1809e-01,  3.1704e-02]]],\n",
      "\n",
      "\n",
      "        [[[-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00],\n",
      "          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n",
      "\n",
      "\n",
      "        [[[ 7.6976e-02, -3.8627e-02,  1.2610e-01,  1.1994e-01,  2.1706e-03],\n",
      "          [ 7.4357e-02,  6.7929e-02,  3.1386e-02,  1.4606e-01,  2.1429e-01],\n",
      "          [-2.6569e-01, -4.2631e-04, -3.6654e-02, -3.0967e-02, -9.4961e-02],\n",
      "          [-2.0192e-01, -3.5423e-01, -2.5246e-01, -3.5092e-01, -2.4159e-01],\n",
      "          [ 1.7636e-02,  1.3744e-01, -1.0306e-01,  8.8370e-02,  7.3258e-02]]],\n",
      "\n",
      "\n",
      "        [[[ 2.0016e-01,  1.0956e-01, -5.9223e-02,  6.4871e-03, -2.4165e-01],\n",
      "          [ 5.6283e-02,  1.7276e-01, -2.2316e-01, -1.6699e-01, -7.0742e-02],\n",
      "          [ 2.6179e-01, -2.5102e-01, -2.0774e-01, -9.6413e-02,  3.4367e-02],\n",
      "          [-9.1882e-02, -2.9195e-01, -8.7432e-02,  1.0144e-01, -2.0559e-02],\n",
      "          [-2.5668e-01, -9.8016e-02,  1.1103e-01, -3.0233e-02,  1.1076e-01]]],\n",
      "\n",
      "\n",
      "        [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n",
      "          [ 0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [ 0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n",
      "\n",
      "\n",
      "        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n",
      "\n",
      "\n",
      "        [[[-1.8396e-01, -6.8774e-03, -1.6675e-01,  7.1980e-03,  1.9922e-02],\n",
      "          [ 1.3416e-01, -1.1450e-01, -1.5277e-01, -6.5713e-02, -9.5435e-02],\n",
      "          [ 1.5406e-01, -9.1235e-02, -1.0880e-01, -7.1603e-02, -9.5575e-02],\n",
      "          [ 2.1772e-01,  8.4073e-02, -2.5264e-01, -2.1428e-01,  1.9537e-01],\n",
      "          [ 1.3124e-01,  7.9532e-02, -2.4044e-01, -1.5717e-01,  1.6562e-01]]],\n",
      "\n",
      "\n",
      "        [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],\n",
      "          [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00],\n",
      "          [-0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00]]]],\n",
      "       device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Run one forward pass on a dummy batch so the pruner's wrappers apply the masks to the weights.\n",
    "\n",
    "sparsify_input = torch.rand(1, 1, 28, 28).to(device)\n",
    "model(sparsify_input)\n",
    "\n",
    "# `conv1` is still wrapped by the pruner, so its weight lives under `.module`; masked filters now print as zeros.\n",
    "\n",
    "print(model.conv1.module.weight.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-07-26 22:26:05] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to pruned_naive_mnist_l1filter.pth\n",
      "[2021-07-26 22:26:05] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to mask_naive_mnist_l1filter.pth\n"
     ]
    }
   ],
   "source": [
    "# Save the pruned weights and the pruning masks; the mask file is what `ModelSpeedup` consumes later.\n",
    "\n",
    "pruned_model_path = 'pruned_naive_mnist_l1filter.pth'\n",
    "pruned_mask_path = 'mask_naive_mnist_l1filter.pth'\n",
    "\n",
    "pruner.export_model(model_path=pruned_model_path, mask_path=pruned_mask_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Speed Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NaiveModel(\n",
      "  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (fc1): Linear(in_features=800, out_features=500, bias=True)\n",
      "  (fc2): Linear(in_features=500, out_features=10, bias=True)\n",
      "  (relu1): ReLU6()\n",
      "  (relu2): ReLU6()\n",
      "  (relu3): ReLU6()\n",
      "  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# The pruner wrapped each configured layer (hence `model.conv1.module` earlier); unwrap the model before speeding it up.\n",
    "# NOTE(review): `_unwrap_model` is a private (underscore-prefixed) NNI method and may change between releases.\n",
    "\n",
    "pruner._unwrap_model()\n",
    "\n",
    "# the printed model now shows the plain Conv2d / Linear layers again, with no wrappers.\n",
    "\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-1-0f2a9eb92f42>:22: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  x = x.view(-1, x.size()[1:].numel())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) start to speed up the model\n",
      "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) {'conv1': 1, 'conv2': 1}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) dim0 sparsity: 0.500000\n",
      "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) dim1 sparsity: 0.000000\n",
      "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) Dectected conv prune dim\" 0\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) infer module masks...\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for conv1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for max_pool1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for conv2\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu2\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for max_pool2\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::view.9\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.jit_translate/MainThread) View Module output size: [-1, 800]\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for fc1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu3\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for fc2\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::log_softmax.10\n",
      "[2021-07-26 22:26:18] ERROR (nni.compression.pytorch.speedup.jit_translate/MainThread) aten::log_softmax is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for .aten::log_softmax.10\n",
      "[2021-07-26 22:26:18] WARNING (nni.compression.pytorch.speedup.compressor/MainThread) Note: .aten::log_softmax.10 does not have corresponding mask inference object\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for fc2\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the fc2\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu3\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for fc1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the fc1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for .aten::view.9\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the .aten::view.9\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for max_pool2\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the max_pool2\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu2\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu2\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for conv2\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the conv2\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for max_pool1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the max_pool1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for conv1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the conv1\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) resolve the mask conflict\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace compressed modules...\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: conv1, op_type: Conv2d)\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu1, op_type: ReLU6)\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: max_pool1, op_type: MaxPool2d)\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: conv2, op_type: Conv2d)\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu2, op_type: ReLU6)\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: max_pool2, op_type: MaxPool2d)\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::view.9, op_type: aten::view) which is func type\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: fc1, op_type: Linear)\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace linear with new in_features: 800, out_features: 500\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu3, op_type: ReLU6)\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: fc2, op_type: Linear)\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace linear with new in_features: 500, out_features: 10\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::log_softmax.10, op_type: aten::log_softmax) which is func type\n",
      "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) speedup done\n"
     ]
    }
   ],
   "source": [
    "from nni.compression.pytorch import ModelSpeedup\n",
    "\n",
    "# Trace the unwrapped model with a dummy batch and shrink its layers according to the saved masks.\n",
    "speedup_input = torch.rand(10, 1, 28, 28).to(device)\n",
    "m_speedup = ModelSpeedup(model, dummy_input=speedup_input, masks_file='mask_naive_mnist_l1filter.pth')\n",
    "m_speedup.speedup_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NaiveModel(\n",
      "  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (conv2): Conv2d(10, 50, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (fc1): Linear(in_features=800, out_features=500, bias=True)\n",
      "  (fc2): Linear(in_features=500, out_features=10, bias=True)\n",
      "  (relu1): ReLU6()\n",
      "  (relu2): ReLU6()\n",
      "  (relu3): ReLU6()\n",
      "  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# `conv1` has been replaced: `Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))` became `Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))`,\n",
    "# i.e. the 10 masked filters (dim-0 sparsity 0.5) were physically removed.\n",
    "# The following layer `conv2` changed as well, because its input channels must match the output channels of `conv1`.\n",
    "\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 0 [0/60000 (0%)]\tLoss: 0.306930\n",
      "Train Epoch: 0 [6400/60000 (11%)]\tLoss: 0.045807\n",
      "Train Epoch: 0 [12800/60000 (21%)]\tLoss: 0.049293\n",
      "Train Epoch: 0 [19200/60000 (32%)]\tLoss: 0.031464\n",
      "Train Epoch: 0 [25600/60000 (43%)]\tLoss: 0.005392\n",
      "Train Epoch: 0 [32000/60000 (53%)]\tLoss: 0.005652\n",
      "Train Epoch: 0 [38400/60000 (64%)]\tLoss: 0.040619\n",
      "Train Epoch: 0 [44800/60000 (75%)]\tLoss: 0.016515\n",
      "Train Epoch: 0 [51200/60000 (85%)]\tLoss: 0.092886\n",
      "Train Epoch: 0 [57600/60000 (96%)]\tLoss: 0.041380\n",
      "\n",
      "Test set: Average loss: 0.0257, Accuracy: 9917/10000 (99%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Fine-tune the sped-up model to recover accuracy.\n",
    "# Speedup replaced the pruned modules, so the earlier optimizer still points at the old parameters: build a new one.\n",
    "\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
    "\n",
    "for epoch in range(1):\n",
    "    trainer(model, optimizer, criterion, epoch)\n",
    "    evaluator(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5. Prepare config_list for quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quantize both the weights and the inputs of `conv1` and `conv2` to 8 bits.\n",
    "quantized_targets = ('weight', 'input')\n",
    "config_list = [{\n",
    "    'quant_types': list(quantized_targets),\n",
    "    'quant_bits': {target: 8 for target in quantized_targets},\n",
    "    'op_names': ['conv1', 'conv2'],\n",
    "}]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 6. Choose a quantizer and quantize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NaiveModel(\n",
       "  (conv1): QuantizerModuleWrapper(\n",
       "    (module): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n",
       "  )\n",
       "  (conv2): QuantizerModuleWrapper(\n",
       "    (module): Conv2d(10, 50, kernel_size=(5, 5), stride=(1, 1))\n",
       "  )\n",
       "  (fc1): Linear(in_features=800, out_features=500, bias=True)\n",
       "  (fc2): Linear(in_features=500, out_features=10, bias=True)\n",
       "  (relu1): ReLU6()\n",
       "  (relu2): ReLU6()\n",
       "  (relu3): ReLU6()\n",
       "  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer\n",
    "\n",
    "# Wrap the layers named in `config_list` for quantization-aware training, handing over the fine-tuning optimizer.\n",
    "quantizer = QAT_Quantizer(model, config_list, optimizer=optimizer)\n",
    "quantizer.compress()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 0 [0/60000 (0%)]\tLoss: 0.004960\n",
      "Train Epoch: 0 [6400/60000 (11%)]\tLoss: 0.036269\n",
      "Train Epoch: 0 [12800/60000 (21%)]\tLoss: 0.018744\n",
      "Train Epoch: 0 [19200/60000 (32%)]\tLoss: 0.021916\n",
      "Train Epoch: 0 [25600/60000 (43%)]\tLoss: 0.003095\n",
      "Train Epoch: 0 [32000/60000 (53%)]\tLoss: 0.003947\n",
      "Train Epoch: 0 [38400/60000 (64%)]\tLoss: 0.032094\n",
      "Train Epoch: 0 [44800/60000 (75%)]\tLoss: 0.017358\n",
      "Train Epoch: 0 [51200/60000 (85%)]\tLoss: 0.083886\n",
      "Train Epoch: 0 [57600/60000 (96%)]\tLoss: 0.040433\n",
      "\n",
      "Test set: Average loss: 0.0247, Accuracy: 9917/10000 (99%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Quantization-aware fine-tuning; this is the pass whose tracked input ranges\n",
    "# (`tracked_min_input` / `tracked_max_input`) end up in the exported calibration config.\n",
    "\n",
    "for epoch in range(1):\n",
    "    trainer(model, optimizer, criterion, epoch)\n",
    "    evaluator(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2021-07-26 22:34:41] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to quantized_naive_mnist_l1filter.pth\n",
      "[2021-07-26 22:34:41] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to calibration_naive_mnist_l1filter.pth\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'conv1': {'weight_bit': 8,\n",
       "  'tracked_min_input': -0.42417848110198975,\n",
       "  'tracked_max_input': 2.8212687969207764},\n",
       " 'conv2': {'weight_bit': 8,\n",
       "  'tracked_min_input': 0.0,\n",
       "  'tracked_max_input': 4.246923446655273}}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# export the quantized model state to './quantized_naive_mnist_l1filter.pth'.\n",
    "# export the calibration config to './calibration_naive_mnist_l1filter.pth'.\n",
    "# the returned calibration config (per-layer weight bits and tracked input range) is displayed below.\n",
    "\n",
    "quantizer.export_model(model_path='quantized_naive_mnist_l1filter.pth', calibration_path='calibration_naive_mnist_l1filter.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 7. Speed Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# speed up with TensorRT (TensorRT must be installed for this cell to run).\n",
    "\n",
    "from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT\n",
    "\n",
    "# Reload the calibration config that `quantizer.export_model` saved, so this cell does not\n",
    "# rely on an in-memory variable; no earlier cell assigns `calibration_config`.\n",
    "calibration_config = torch.load('calibration_naive_mnist_l1filter.pth')\n",
    "\n",
    "trt_batch_size = 32\n",
    "engine = ModelSpeedupTensorRT(model, (trt_batch_size, 1, 28, 28), config=calibration_config, batchsize=trt_batch_size)\n",
    "engine.compress()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
